import pandas as pd


def read_data(path_info, path_label):
    """Load the info and label workbooks and give them fixed column names.

    Args:
        path_info: Path of the Excel file holding the per-date readings.
            It must have exactly six columns.
        path_label: Path of the Excel file holding the per-user labels.
            It must have exactly three columns.

    Returns:
        A ``(train_info, train_label)`` tuple of DataFrames whose columns
        have been renamed positionally.
    """
    info_frame, label_frame = pd.read_excel(path_info), pd.read_excel(path_label)

    # Rename positionally; the original header text is discarded.
    info_frame.columns = ['user_id', 'date', 'sum', 'max', 'mean', 'min']
    label_frame.columns = ['user_id', 'user_nature', 'error']

    return info_frame, label_frame


def get_date(value):
    """Normalize a slash-separated date string and convert it to an int.

    The string is split on "/". The second and third fields (presumably
    month and day) get a leading "0" when they are exactly one character
    long; the first three fields are then concatenated and parsed as an
    integer, e.g. ``"2015/1/5"`` -> ``20150105``.

    Raises:
        IndexError: If the string has fewer than three "/"-separated fields.
    """
    def pad(field):
        # Only single-character fields are padded; anything else is kept.
        return '0' + field if len(field) == 1 else field

    fields = value.split("/")
    year, month, day = fields[0], pad(fields[1]), pad(fields[2])
    return int(year + month + day)


def sort_list(list_res):
    """Sort records in place by their first element and return the list.

    Each record is a sequence whose element 0 is the sort key (the callers
    in this file store the integer date there). The sort is ascending and
    stable, so records with equal keys keep their relative order.

    Args:
        list_res: List of records; it is reordered in place.

    Returns:
        The same list object, now sorted.
    """
    # Built-in stable sort replaces the hand-written bubble sort.
    list_res.sort(key=lambda record: record[0])
    return list_res


def get_all_date(path_info, path_label):
    """Return all dates found in the info workbook, in ascending order.

    Every distinct raw value of the ``date`` column is converted with
    ``get_date``; the resulting integers are returned as a sorted list.
    Both workbooks are loaded through ``read_data``, but only the info
    sheet is used.
    """
    info_frame, _ = read_data(path_info, path_label)
    raw_dates = info_frame['date'].unique().tolist()
    return sorted(get_date(raw) for raw in raw_dates)


def switch_data(train_info):
    """Turn per-user row records into per-user column series.

    Args:
        train_info: One entry per user; each entry is a list of records
            indexed as ``[_, sum, max, mean, min, nature]`` (element 0 is
            not read here).

    Returns:
        One entry per user of the form
        ``[sums, maxes, means, mins, nature]``, where the first four are
        lists with one value per record and ``nature`` is a list holding
        the nature of the first record only (empty when the user has no
        records).
    """
    converted = []
    for records in train_info:
        # Positions 1..4 hold sum, max, mean and min respectively.
        sums, maxes, means, mins = (
            [record[position] for record in records]
            for position in (1, 2, 3, 4)
        )
        # The nature is taken once, from the first record.
        nature = [record[5] for record in records[:1]]
        converted.append([sums, maxes, means, mins, nature])

    return converted



def get_data(path_info, path_label):
    """Build per-user feature series together with their anomaly labels.

    The two workbooks are loaded with ``read_data``, the textual
    ``user_nature`` and ``error`` values are replaced by integer codes,
    and the readings are grouped per user, sorted by date and reshaped
    with ``switch_data``.

    Args:
        path_info: Path of the Excel file with the per-date readings.
        path_label: Path of the Excel file with the per-user labels.

    Returns:
        A ``(series, labels)`` tuple. ``series[i]`` is
        ``[sums, maxes, means, mins, nature]`` for the i-th user (users are
        ordered by first appearance in the merged data) and ``labels[i]``
        is that same user's ``error`` value, so both lists have the same
        length and line up index by index.
    """
    train_info, train_label = read_data(path_info, path_label)

    # Encode the user nature as integers.
    train_label.loc[train_label['user_nature'] == "低压居民", "user_nature"] = 0
    train_label.loc[train_label['user_nature'] == "低压非居民", "user_nature"] = 1
    train_label.loc[train_label['user_nature'] == "高压", "user_nature"] = 2

    # Encode the anomaly flag as integers.
    train_label.loc[train_label['error'] == "无异常", "error"] = 0
    train_label.loc[train_label['error'] == "有异常", "error"] = 1

    # Attach the user nature to every reading; the label is not a feature,
    # so the error column is dropped from the merged frame.
    merged = pd.merge(train_info, train_label, on="user_id").drop(columns="error")

    # Labels are looked up per user so they stay aligned with the series
    # even when the label sheet is ordered differently from the readings
    # or contains users without any reading.
    label_by_user = dict(zip(train_label['user_id'], train_label['error']))

    feature_columns = ['sum', 'max', 'mean', 'min', 'user_nature']
    data_train = []
    data_train_label = []
    # sort=False keeps users in order of first appearance.
    for user_id, user_rows in merged.groupby("user_id", sort=False):
        dates = [get_date(raw) for raw in user_rows['date'].to_numpy()]
        features = zip(*(user_rows[col].to_numpy() for col in feature_columns))
        # Each record is [date, sum, max, mean, min, user_nature].
        records = [[date, *values] for date, values in zip(dates, features)]
        data_train.append(sort_list(records))
        data_train_label.append(label_by_user[user_id])

    # Reshape the date-sorted records into per-user column series.
    date_train_info = switch_data(data_train)

    return date_train_info, data_train_label


if __name__ == '__main__':
    # Manual smoke run: load the training workbooks and show the first
    # user's series followed by the first label.
    path_info, path_label = (
        "../../../dataset/train_info.xlsx",
        "../../../dataset/train_label.xlsx",
    )

    data_train_info, data_train_label = get_data(path_info, path_label)

    for part in (data_train_info, data_train_label):
        print(part[0])

    # Disabled debugging aid: collect the distinct entry lengths.
    # date_num = []
    # for i in range(len(data_train_info)):
    #     if len(data_train_info[i]) not in date_num:
    #         date_num.append(len(data_train_info[i]))
    # print(date_num)

    # print(data_train_label)


